{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JpFDzKcs4kP3"
   },
   "source": [
    "An example of measuring domain similarity for transfer learning via Earth Mover's Distance (EMD). The results correspond to part of Figure 5 in the original paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "collapsed": true,
    "id": "n1Ve6BcjGGfR"
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "import pyemd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TAsLMdpOg-Im"
   },
   "source": [
    "* Feature extraction on all datasets from a ResNet-101 pre-trained on JFT.\n",
    "* All features are pre-extracted in this example.\n",
    "* Notice that all features are extracted from the training set of each dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "height": 72
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 16418,
     "status": "ok",
     "timestamp": 1532561275077,
     "user": {
      "displayName": "Yin Cui",
      "photoUrl": "//lh3.googleusercontent.com/-g-Uzyho0wow/AAAAAAAAAAI/AAAAAAAAADY/zzsPhElO8S4/s50-c-k-no/photo.jpg",
      "userId": "112568385050425502553"
     },
     "user_tz": 240
    },
    "id": "165pufuXg8wy",
    "outputId": "5897e0a9-d4b4-4099-bde7-c49c3b39a8e2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original feature shape: (5994, 2048)\n",
      "Number of classes: 200\n",
      "Feature per class shape: (200, 2048)\n"
     ]
    }
   ],
   "source": [
    "# In this example on CUB-200, we demonstrate how to calculate feature and weight for each class.\n",
    "feature_dir = './feature/resnet_101_JFT_299/'\n",
    "dataset = 'cub_200'\n",
    "\n",
    "# Load extracted features on CUB-200.\n",
    "feature = np.load(feature_dir + dataset + '_feature.npy')\n",
    "label = np.load(feature_dir + dataset + '_label.npy')\n",
    "\n",
    "# CUB-200 training set contains 5994 images from 200 classes, each image is \n",
    "# represented by a 2048-dimensional feature from the pre-trained ResNet-101.\n",
    "print('Original feature shape: (%d, %d)' % (feature.shape[0], feature.shape[1]))\n",
    "print('Number of classes: %d' % (len(np.unique(label))))\n",
    "\n",
    "# Calculate class feature as the averaged features among all images of the class.\n",
    "# Class weight is defined as the number of images of the class.\n",
    "sorted_label = sorted(list(set(label)))\n",
    "feature_per_class = np.zeros((len(sorted_label), 2048), dtype=np.float32)\n",
    "weight = np.zeros((len(sorted_label), ), dtype=np.float32)\n",
    "counter = 0\n",
    "for i in sorted_label:\n",
    "    idx = [(l==i) for l in label]\n",
    "    feature_per_class[counter, :] = np.mean(feature[idx, :], axis=0)\n",
    "    weight[counter] = np.sum(idx)\n",
    "    counter += 1\n",
    "\n",
    "print('Feature per class shape: (%d, %d)' % (feature_per_class.shape[0], \n",
    "                                             feature_per_class.shape[1]))\n",
    "\n",
    "np.save(feature_dir + dataset + '.npy', feature_per_class)\n",
    "np.save(feature_dir + dataset + '_weight.npy', weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dQHzSchplpH4"
   },
   "source": [
    "*   Calculate feature per class and weight for all datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "height": 831
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2344197,
     "status": "ok",
     "timestamp": 1532563619354,
     "user": {
      "displayName": "Yin Cui",
      "photoUrl": "//lh3.googleusercontent.com/-g-Uzyho0wow/AAAAAAAAAAI/AAAAAAAAADY/zzsPhElO8S4/s50-c-k-no/photo.jpg",
      "userId": "112568385050425502553"
     },
     "user_tz": 240
    },
    "id": "QqJS27R-lnqK",
    "outputId": "25315e76-5a86-4993-ea37-d0210a931864"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ImageNet_train --> cub_200\n",
      "EMD: 57.532    Domain Similarity: 0.563\n",
      "\n",
      "ImageNet_train --> flower_102\n",
      "EMD: 64.492    Domain Similarity: 0.525\n",
      "\n",
      "ImageNet_train --> stanford_cars\n",
      "EMD: 57.993    Domain Similarity: 0.560\n",
      "\n",
      "ImageNet_train --> stanford_dogs\n",
      "EMD: 48.044    Domain Similarity: 0.619\n",
      "\n",
      "ImageNet_train --> aircraft\n",
      "EMD: 58.717    Domain Similarity: 0.556\n",
      "\n",
      "ImageNet_train --> nabirds\n",
      "EMD: 55.595    Domain Similarity: 0.574\n",
      "\n",
      "ImageNet_train --> food_101\n",
      "EMD: 57.349    Domain Similarity: 0.564\n",
      "\n",
      "inat_train --> cub_200\n",
      "EMD: 42.961    Domain Similarity: 0.651\n",
      "\n",
      "inat_train --> flower_102\n",
      "EMD: 61.340    Domain Similarity: 0.542\n",
      "\n",
      "inat_train --> stanford_cars\n",
      "EMD: 62.488    Domain Similarity: 0.535\n",
      "\n",
      "inat_train --> stanford_dogs\n",
      "EMD: 55.880    Domain Similarity: 0.572\n",
      "\n",
      "inat_train --> aircraft\n",
      "EMD: 61.148    Domain Similarity: 0.543\n",
      "\n",
      "inat_train --> nabirds\n",
      "EMD: 37.703    Domain Similarity: 0.686\n",
      "\n",
      "inat_train --> food_101\n",
      "EMD: 62.587    Domain Similarity: 0.535\n",
      "\n",
      "ImageNet+inat --> cub_200\n",
      "EMD: 53.868    Domain Similarity: 0.584\n",
      "\n",
      "ImageNet+inat --> flower_102\n",
      "EMD: 63.508    Domain Similarity: 0.530\n",
      "\n",
      "ImageNet+inat --> stanford_cars\n",
      "EMD: 58.805    Domain Similarity: 0.555\n",
      "\n",
      "ImageNet+inat --> stanford_dogs\n",
      "EMD: 49.716    Domain Similarity: 0.608\n",
      "\n",
      "ImageNet+inat --> aircraft\n",
      "EMD: 59.176    Domain Similarity: 0.553\n",
      "\n",
      "ImageNet+inat --> nabirds\n",
      "EMD: 51.084    Domain Similarity: 0.600\n",
      "\n",
      "ImageNet+inat --> food_101\n",
      "EMD: 58.395    Domain Similarity: 0.558\n",
      "\n",
      "Elapsed time: 2515.816s\n"
     ]
    }
   ],
   "source": [
    "# Calculate domain similarity by Earth Mover's Distance (EMD).\n",
    "\n",
    "# Set minimum number of images per class for computational efficiency.\n",
    "# Classes in source domain with less than min_num_imgs images will be ignored.\n",
    "min_num_imgs = 200\n",
    "\n",
    "# Gamma for domain similarity: exp(-gamma x EMD)\n",
    "gamma = 0.01\n",
    "\n",
    "# Three source domain datasets: \n",
    "# ImageNet (ILSVRC 2012) training set,\n",
    "# iNaturalist 2017 training set (original training + 90% validation), \n",
    "# ImageNet + iNaturalist training set.\n",
    "source_domain = ['ImageNet_train', 'inat_train', 'ImageNet+inat']\n",
    "\n",
    "# Seven target domain datasets (all of them are from the training set):\n",
    "# CUB-200-2011 Bird, Oxford Flower 102, Stanford Car, Stanford Dog, \n",
    "# FGVC-Aircraft, NABirds, Food 101\n",
    "target_domain = ['cub_200', 'flower_102', 'stanford_cars', 'stanford_dogs', \n",
    "                 'aircraft', 'nabirds', 'food_101']\n",
    "\n",
    "# Create ImageNet + iNaturalist feature and weight by concatenation.\n",
    "f_1 = np.load(feature_dir + 'ImageNet_train' + '.npy')\n",
    "w_1 = np.load(feature_dir + 'ImageNet_train' + '_weight.npy')\n",
    "f_2 = np.load(feature_dir + 'inat_train' + '.npy')\n",
    "w_2 = np.load(feature_dir + 'inat_train' + '_weight.npy')\n",
    "f = np.append(f_1, f_2, axis=0)\n",
    "w = np.append(w_1, w_2, axis=0)\n",
    "np.save(feature_dir + 'ImageNet+inat.npy', f)\n",
    "np.save(feature_dir + 'ImageNet+inat_weight.npy', w)\n",
    "\n",
    "tic = time.time()\n",
    "for sd in source_domain:\n",
    "    for td in target_domain:\n",
    "        print('%s --> %s' % (sd, td))\n",
    "        f_s = np.load(feature_dir + sd + '.npy')\n",
    "        f_t = np.load(feature_dir + td + '.npy')\n",
    "        w_s = np.load(feature_dir + sd + '_weight.npy')\n",
    "        w_t = np.load(feature_dir + td + '_weight.npy')\n",
    "\n",
    "        # Remove source domain classes with number of images < 'min_num_imgs'.\n",
    "        idx = [i for i in range(len(w_s)) if w_s[i] >= min_num_imgs]\n",
    "        f_s = f_s[idx, :]\n",
    "        w_s = w_s[idx]\n",
    "\n",
    "        # Make sure two histograms have the same length and distance matrix is square.\n",
    "        data = np.float64(np.append(f_s, f_t, axis=0))\n",
    "        w_1 = np.zeros((len(w_s) + len(w_t),), np.float64)\n",
    "        w_2 = np.zeros((len(w_s) + len(w_t),), np.float64)\n",
    "        w_1[:len(w_s)] = w_s / np.sum(w_s)\n",
    "        w_2[len(w_s):] = w_t / np.sum(w_t)\n",
    "        D = euclidean_distances(data, data)\n",
    "\n",
    "        emd = pyemd.emd(np.float64(w_1), np.float64(w_2), np.float64(D))\n",
    "        print('EMD: %.3f    Domain Similarity: %.3f\\n' % (emd, np.exp(-gamma*emd)))\n",
    "print('Elapsed time: %.3fs' % (time.time() - tic))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "DomainSimilarity_EMD.ipynb",
   "provenance": [
    {
     "file_id": "1-Tah9eTTk4hYyoxPok2PJTSs4L0DJVAg",
     "timestamp": 1532563779271
    }
   ],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
